diff --git a/sklearn/utils/_pprint.py b/sklearn/utils/_pprint.py
index 42f4d14cf..5627c4184 100644
--- a/sklearn/utils/_pprint.py
+++ b/sklearn/utils/_pprint.py
@@ -70,7 +70,7 @@ from collections import OrderedDict
 from ..base import BaseEstimator
 from .._config import get_config
 from . import is_scalar_nan
-
+import math
 
 class KeyValTuple(tuple):
     """Dummy class for correctly rendering key-value tuples from dicts."""
@@ -156,8 +156,8 @@ class _EstimatorPrettyPrinter(pprint.PrettyPrinter):
                  compact=False, indent_at_name=True,
                  n_max_elements_to_show=None):
         super().__init__(indent, width, depth, stream, compact=compact)
-        self._indent_at_name = indent_at_name
-        if self._indent_at_name:
+        self.indent_at_name = indent_at_name
+        if self.indent_at_name:
             self._indent_per_level = 1  # ignore indent param
         self._changed_only = get_config()['print_changed_only']
         # Max number of elements in a list, dict, tuple until we start using
@@ -169,12 +169,8 @@ class _EstimatorPrettyPrinter(pprint.PrettyPrinter):
         return _safe_repr(object, context, maxlevels, level,
                           changed_only=self._changed_only)
 
-    def _pprint_estimator(self, object, stream, indent, allowance, context,
-                          level):
+    def _pprint_estimator(self, object, stream, indent, allowance, context, level):
         stream.write(object.__class__.__name__ + '(')
-        if self._indent_at_name:
-            indent += len(object.__class__.__name__)
-
         if self._changed_only:
             params = _changed_params(object)
         else:
@@ -321,7 +317,7 @@ class _EstimatorPrettyPrinter(pprint.PrettyPrinter):
         self._format(v, stream, indent + len(rep) + len(middle), allowance,
                      context, level)
 
-    _dispatch = pprint.PrettyPrinter._dispatch
+    _dispatch = pprint.PrettyPrinter._dispatch.copy()
     _dispatch[BaseEstimator.__repr__] = _pprint_estimator
     _dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple
 
@@ -331,7 +327,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
     objects."""
     typ = type(object)
 
-    if typ in pprint._builtin_scalars:
+    if isinstance(object, (int, float, str, bytes, bool, type(None))):
         return repr(object), True, False
 
     r = getattr(typ, "__repr__", None)
@@ -342,7 +338,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
         if maxlevels and level >= maxlevels:
             return "{...}", False, objid in context
         if objid in context:
-            return pprint._recursion(object), False, True
+            return _recursion(object), False, True
         context[objid] = 1
         readable = True
         recursive = False
@@ -350,7 +346,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
         append = components.append
         level += 1
         saferepr = _safe_repr
-        items = sorted(object.items(), key=pprint._safe_tuple)
+        items = sorted(object.items(), key=_safe_tuple)
         for k, v in items:
             krepr, kreadable, krecur = saferepr(
                 k, context, maxlevels, level, changed_only=changed_only)
@@ -379,7 +375,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
         if maxlevels and level >= maxlevels:
             return format % "...", False, objid in context
         if objid in context:
-            return pprint._recursion(object), False, True
+            return _recursion(object), False, True
         context[objid] = 1
         readable = True
         recursive = False
@@ -402,7 +398,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
         if maxlevels and level >= maxlevels:
             return "{...}", False, objid in context
         if objid in context:
-            return pprint._recursion(object), False, True
+            return _recursion(object), False, True
         context[objid] = 1
         readable = True
         recursive = False
@@ -414,7 +410,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
         append = components.append
         level += 1
         saferepr = _safe_repr
-        items = sorted(params.items(), key=pprint._safe_tuple)
+        items = sorted(params.items(), key=_safe_tuple)
         for k, v in items:
             krepr, kreadable, krecur = saferepr(
                 k, context, maxlevels, level, changed_only=changed_only)
@@ -430,3 +426,42 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
 
     rep = repr(object)
     return rep, (rep and not rep.startswith('<')), False
+
+
+class _safe_key:
+    """Sort-key wrapper that keeps mutually unorderable objects sortable.
+
+    Mirrors the private ``pprint._safe_key`` helper (hence the lowercase
+    name). Wrapped objects are compared with ``<`` directly; if that raises
+    ``TypeError`` (e.g. ``1 < 'a'`` or ``None < 0``), the comparison falls
+    back to ``(str(type(obj)), id(obj))`` so that sorting never fails.
+    """
+
+    __slots__ = ['obj']
+
+    def __init__(self, obj):
+        self.obj = obj
+
+    def __lt__(self, other):
+        try:
+            return self.obj < other.obj
+        except TypeError:
+            # Unorderable pair: order by type name, then by identity. The
+            # result is arbitrary but consistent within a single sort.
+            return ((str(type(self.obj)), id(self.obj)) <
+                    (str(type(other.obj)), id(other.obj)))
+
+
+def _safe_tuple(t):
+    """Return a sort key for a tuple such as a ``dict.items()`` entry.
+
+    Every element is wrapped in ``_safe_key`` so that sorting items whose
+    keys have mixed types (e.g. ``int`` and ``str``) does not raise.
+    """
+    return tuple(_safe_key(item) for item in t)
+
+
+def _recursion(object):
+    """Return the placeholder repr used for a self-referencing container."""
+    return "<Recursion on {} with id={}>".format(
+        type(object).__name__, id(object))
